#!/usr/bin/env python3

import argparse
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], ('..')))

from colorization.colorization_model import ColorizationModel
from colorization.modules.colorization_network import ColorizationNetwork
from colorization.util.argparse import nice_help_formatter


# Hand-written usage synopsis passed to argparse instead of the generated one.
# One literal per output line; continuation lines are indented by 20 spaces.
USAGE = (
    "convert_weights [-h|--help]\n"
    "                    {vgg,deeplab}\n"
    "                    OUTFILE\n"
    "                    --weights WEIGHTS\n"
    "                    [--proto PROTO]\n"
)


def _build_parser():
    """Create the argument parser for the weight conversion script."""
    parser = argparse.ArgumentParser(
        formatter_class=nice_help_formatter(),
        usage=USAGE,
        description="Convert pretrained Caffe/Tensorflow weights into a "
                    "PyTorch readable format.")

    parser.add_argument('base_network',
                        choices=['vgg', 'deeplab'],
                        help="base network the pretrained weights belong to")

    parser.add_argument('outfile',
                        metavar='OUTFILE',
                        help="file to which converted PyTorch weights are "
                             "to be written")

    parser.add_argument('--weights',
                        required=True,
                        help="Binary model weights, should be a .caffemodel "
                             "file if the base network is 'vgg' or a "
                             "tensorflow checkpoint if the base network is "
                             "'deeplab'")

    parser.add_argument('--proto',
                        help="Caffe .prototxt file, only needed if the base "
                             "network is 'vgg'")

    return parser


def main(argv=None):
    """Convert pretrained weights and write them to OUTFILE as a checkpoint.

    Args:
        argv: Command line arguments without the program name. ``None``
            lets argparse fall back to ``sys.argv[1:]``.

    Exits with argparse's usage error (status 2) when the command line is
    invalid, including a missing ``--proto`` for the 'vgg' base network.
    """
    # parse command line arguments
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Fail fast: reject an incomplete invocation before the network is
    # constructed, and report it the same way argparse reports a missing
    # required option.
    if args.base_network == 'vgg' and args.proto is None:
        parser.error("--proto must be given when the base network is 'vgg'")

    # convert weights
    network = ColorizationNetwork(base_network=args.base_network, device='cpu')

    if args.base_network == 'vgg':
        network.base_network.init_from_caffe(args.proto, args.weights)
    elif args.base_network == 'deeplab':
        network.base_network.init_from_tensorflow(args.weights)

    model = ColorizationModel(network)
    model.save_checkpoint(args.outfile)


if __name__ == '__main__':
    main()
